{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BWBYMn_h8GFv"
      },
      "source": [
        "Note: The evals here have been run on GPU so they may not exactly match the results reported in the paper which were run on TPUs, however the difference in accuracy should not be more than 0.1%."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8RXaPu45gJmX"
      },
      "source": [
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "executionInfo": {
          "elapsed": 97,
          "status": "ok",
          "timestamp": 1605000805970,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "bpOA40tNsLH_"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "executionInfo": {
          "elapsed": 85,
          "status": "ok",
          "timestamp": 1605000805970,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "bZUxod8VxQdD"
      },
      "outputs": [],
      "source": [
        "CROP_PROPORTION = 0.875  # Standard for ImageNet.\n",
        "HEIGHT = 224\n",
        "WIDTH = 224\n",
        "\n",
        "def _compute_crop_shape(\n",
        "    image_height, image_width, aspect_ratio, crop_proportion):\n",
        "  \"\"\"Compute aspect ratio-preserving shape for central crop.\n",
        "\n",
        "  The resulting shape retains `crop_proportion` along one side and a proportion\n",
        "  less than or equal to `crop_proportion` along the other side.\n",
        "\n",
        "  Args:\n",
        "    image_height: Height of image to be cropped.\n",
        "    image_width: Width of image to be cropped.\n",
        "    aspect_ratio: Desired aspect ratio (width / height) of output.\n",
        "    crop_proportion: Proportion of image to retain along the less-cropped side.\n",
        "\n",
        "  Returns:\n",
        "    crop_height: Height of image after cropping.\n",
        "    crop_width: Width of image after cropping.\n",
        "  \"\"\"\n",
        "  image_width_float = tf.cast(image_width, tf.float32)\n",
        "  image_height_float = tf.cast(image_height, tf.float32)\n",
        "\n",
        "  def _requested_aspect_ratio_wider_than_image():\n",
        "    crop_height = tf.cast(tf.math.rint(\n",
        "        crop_proportion / aspect_ratio * image_width_float), tf.int32)\n",
        "    crop_width = tf.cast(tf.math.rint(\n",
        "        crop_proportion * image_width_float), tf.int32)\n",
        "    return crop_height, crop_width\n",
        "\n",
        "  def _image_wider_than_requested_aspect_ratio():\n",
        "    crop_height = tf.cast(\n",
        "        tf.math.rint(crop_proportion * image_height_float), tf.int32)\n",
        "    crop_width = tf.cast(tf.math.rint(\n",
        "        crop_proportion * aspect_ratio *\n",
        "        image_height_float), tf.int32)\n",
        "    return crop_height, crop_width\n",
        "\n",
        "  return tf.cond(\n",
        "      aspect_ratio \u003e image_width_float / image_height_float,\n",
        "      _requested_aspect_ratio_wider_than_image,\n",
        "      _image_wider_than_requested_aspect_ratio)\n",
        "\n",
        "\n",
        "def center_crop(image, height, width, crop_proportion):\n",
        "  \"\"\"Crops to center of image and rescales to desired size.\n",
        "\n",
        "  Args:\n",
        "    image: Image Tensor to crop.\n",
        "    height: Height of image to be cropped.\n",
        "    width: Width of image to be cropped.\n",
        "    crop_proportion: Proportion of image to retain along the less-cropped side.\n",
        "\n",
        "  Returns:\n",
        "    A `height` x `width` x channels Tensor holding a central crop of `image`.\n",
        "  \"\"\"\n",
        "  shape = tf.shape(image)\n",
        "  image_height = shape[0]\n",
        "  image_width = shape[1]\n",
        "  crop_height, crop_width = _compute_crop_shape(\n",
        "      image_height, image_width, height / width, crop_proportion)\n",
        "  offset_height = ((image_height - crop_height) + 1) // 2\n",
        "  offset_width = ((image_width - crop_width) + 1) // 2\n",
        "  image = tf.image.crop_to_bounding_box(\n",
        "      image, offset_height, offset_width, crop_height, crop_width)\n",
        "  image = tf.image.resize(image, [height, width],\n",
        "                          method=tf.image.ResizeMethod.BICUBIC)\n",
        "  return image\n",
        "\n",
        "def preprocess_for_eval(image, height, width):\n",
        "  \"\"\"Preprocesses the given image for evaluation.\n",
        "\n",
        "  Args:\n",
        "    image: `Tensor` representing an image of arbitrary size.\n",
        "    height: Height of output image.\n",
        "    width: Width of output image.\n",
        "\n",
        "  Returns:\n",
        "    A preprocessed image `Tensor`.\n",
        "  \"\"\"\n",
        "  image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION)\n",
        "  image = tf.reshape(image, [height, width, 3])\n",
        "  image = tf.clip_by_value(image, 0., 1.)\n",
        "  return image\n",
        "\n",
        "def preprocess_image(features):\n",
        "  \"\"\"Preprocesses the given image.\n",
        "\n",
        "  Args:\n",
        "    image: `Tensor` representing an image of arbitrary size.\n",
        "\n",
        "  Returns:\n",
        "    A preprocessed image `Tensor` of range [0, 1].\n",
        "  \"\"\"\n",
        "  image = features[\"image\"]\n",
        "  image = tf.image.convert_image_dtype(image, dtype=tf.float32)\n",
        "  image = preprocess_for_eval(image, HEIGHT, WIDTH)\n",
        "  features[\"image\"] = image\n",
        "  return features"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nAmEFj9V3qMI"
      },
      "source": [
        "Load dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "executionInfo": {
          "elapsed": 82,
          "status": "ok",
          "timestamp": 1605000805971,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "xapZFsX5sa4y"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 50\n",
        "ds = tfds.load(name='imagenet2012', split='validation').map(preprocess_image).batch(BATCH_SIZE).prefetch(1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "executionInfo": {
          "elapsed": 75,
          "status": "ok",
          "timestamp": 1605000805971,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "2KeU_Qy48mCA"
      },
      "outputs": [],
      "source": [
        "def eval(model_path, log=False):\n",
        "  if log:\n",
        "    print(\"Loading model from %s\" % model_path)\n",
        "  model = tf.saved_model.load(model_path)\n",
        "  if log:\n",
        "    print(\"Loaded model!\")\n",
        "  top_1_accuracy = tf.keras.metrics.Accuracy('top_1_accuracy')\n",
        "  for i, features in enumerate(ds):\n",
        "    logits = model(features[\"image\"], trainable=False)[\"logits_sup\"]\n",
        "    top_1_accuracy.update_state(features[\"label\"], tf.argmax(logits, axis=-1))\n",
        "    if log and (i + 1) % 50 == 0:\n",
        "      print(\"Finished %d examples\" % ((i + 1) * BATCH_SIZE))\n",
        "  return top_1_accuracy.result().numpy().astype(float)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pGalUcFNE7Ee"
      },
      "source": [
        "# SimCLR v2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c1KYQede2Hh3"
      },
      "source": [
        "Finetuned models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "executionInfo": {
          "elapsed": 70,
          "status": "ok",
          "timestamp": 1605000805971,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "2HVjY0P9o32d",
        "outputId": "47160ad0-0fbe-43c2-957a-5829ec78c5f5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_1x_sk0/saved_model/\n",
            "Top-1: 58.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_1x_sk0/saved_model/\n",
            "Top-1: 68.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk0/saved_model/\n",
            "Top-1: 76.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_1x_sk1/saved_model/\n",
            "Top-1: 64.5\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_1x_sk1/saved_model/\n",
            "Top-1: 72.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_1x_sk1/saved_model/\n",
            "Top-1: 78.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_2x_sk0/saved_model/\n",
            "Top-1: 66.2\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_2x_sk0/saved_model/\n",
            "Top-1: 73.9\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_2x_sk0/saved_model/\n",
            "Top-1: 79.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r50_2x_sk1/saved_model/\n",
            "Top-1: 70.7\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r50_2x_sk1/saved_model/\n",
            "Top-1: 77.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r50_2x_sk1/saved_model/\n",
            "Top-1: 81.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r101_1x_sk0/saved_model/\n",
            "Top-1: 62.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r101_1x_sk0/saved_model/\n",
            "Top-1: 71.5\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r101_1x_sk0/saved_model/\n",
            "Top-1: 78.2\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r101_1x_sk1/saved_model/\n",
            "Top-1: 68.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r101_1x_sk1/saved_model/\n",
            "Top-1: 75.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r101_1x_sk1/saved_model/\n",
            "Top-1: 80.7\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r101_2x_sk0/saved_model/\n",
            "Top-1: 69.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r101_2x_sk0/saved_model/\n",
            "Top-1: 75.9\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r101_2x_sk0/saved_model/\n",
            "Top-1: 80.8\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r101_2x_sk1/saved_model/\n",
            "Top-1: 73.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r101_2x_sk1/saved_model/\n",
            "Top-1: 78.8\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r101_2x_sk1/saved_model/\n",
            "Top-1: 82.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r152_1x_sk0/saved_model/\n",
            "Top-1: 64.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r152_1x_sk0/saved_model/\n",
            "Top-1: 73.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r152_1x_sk0/saved_model/\n",
            "Top-1: 79.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r152_1x_sk1/saved_model/\n",
            "Top-1: 70.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r152_1x_sk1/saved_model/\n",
            "Top-1: 76.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r152_1x_sk1/saved_model/\n",
            "Top-1: 81.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r152_2x_sk0/saved_model/\n",
            "Top-1: 70.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r152_2x_sk0/saved_model/\n",
            "Top-1: 76.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r152_2x_sk0/saved_model/\n",
            "Top-1: 81.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r152_2x_sk1/saved_model/\n",
            "Top-1: 74.2\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r152_2x_sk1/saved_model/\n",
            "Top-1: 79.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r152_2x_sk1/saved_model/\n",
            "Top-1: 82.9\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_1pct/r152_3x_sk1/saved_model/\n",
            "Top-1: 74.9\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_10pct/r152_3x_sk1/saved_model/\n",
            "Top-1: 80.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/finetuned_100pct/r152_3x_sk1/saved_model/\n",
            "Top-1: 83.1\n"
          ]
        }
      ],
      "source": [
        "path_pat = \"gs://simclr-checkpoints-tf2/simclrv2/finetuned_{pct}pct/r{depth}_{width_multiplier}x_sk{sk}/saved_model/\"\n",
        "results = {}\n",
        "\n",
        "for resnet_depth in (50, 101, 152):\n",
        "  for width_multiplier in (1, 2):\n",
        "    for sk in (0, 1):\n",
        "      for pct in (1, 10, 100):\n",
        "        path = path_pat.format(pct=pct, depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "        results[path] = eval(path)\n",
        "        print(path)\n",
        "        print(\"Top-1: %.1f\" % (results[path] * 100))\n",
        "\n",
        "resnet_depth = 152\n",
        "width_multiplier = 3\n",
        "sk = 1\n",
        "for pct in (1, 10, 100):\n",
        "  path = path_pat.format(pct=pct, depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "  results[path] = eval(path)\n",
        "  print(path)\n",
        "  print(\"Top-1: %.1f\" % (results[path] * 100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rb85INgi5t7X"
      },
      "source": [
        "Supervised"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "executionInfo": {
          "elapsed": 11639325,
          "status": "ok",
          "timestamp": 1605042406803,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "rsCnL0ISzF8h",
        "outputId": "791b69ef-a9ae-4ff4-900f-167a50cac814"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_1x_sk0/saved_model/\n",
            "Top-1: 76.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_1x_sk1/saved_model/\n",
            "Top-1: 78.5\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_2x_sk0/saved_model/\n",
            "Top-1: 77.8\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r50_2x_sk1/saved_model/\n",
            "Top-1: 79.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_1x_sk0/saved_model/\n",
            "Top-1: 78.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_1x_sk1/saved_model/\n",
            "Top-1: 79.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_2x_sk0/saved_model/\n",
            "Top-1: 78.8\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r101_2x_sk1/saved_model/\n",
            "Top-1: 80.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_1x_sk0/saved_model/\n",
            "Top-1: 78.2\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_1x_sk1/saved_model/\n",
            "Top-1: 80.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_2x_sk0/saved_model/\n",
            "Top-1: 79.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_2x_sk1/saved_model/\n",
            "Top-1: 80.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/supervised/r152_3x_sk1/saved_model/\n",
            "Top-1: 80.5\n"
          ]
        }
      ],
      "source": [
        "path_pat = \"gs://simclr-checkpoints-tf2/simclrv2/supervised/r{depth}_{width_multiplier}x_sk{sk}/saved_model/\"\n",
        "results = {}\n",
        "\n",
        "for resnet_depth in (50, 101, 152):\n",
        "  for width_multiplier in (1, 2):\n",
        "    for sk in (0, 1):\n",
        "      path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "      results[path] = eval(path)\n",
        "      print(path)\n",
        "      print(\"Top-1: %.1f\" % (results[path] * 100))\n",
        "\n",
        "resnet_depth = 152\n",
        "width_multiplier = 3\n",
        "sk = 1\n",
        "path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "results[path] = eval(path)\n",
        "print(path)\n",
        "print(\"Top-1: %.1f\" % (results[path] * 100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vT78TaUN6BX7"
      },
      "source": [
        "Pretrained with linear eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "executionInfo": {
          "elapsed": 510913,
          "status": "ok",
          "timestamp": 1605001316821,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "r3IbPoYg6A84",
        "outputId": "7f9dfbd8-8bae-4436-e1aa-3e124df92ff1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk0/saved_model/\n",
            "Top-1: 71.7\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_1x_sk1/saved_model/\n",
            "Top-1: 74.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_2x_sk0/saved_model/\n",
            "Top-1: 75.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r50_2x_sk1/saved_model/\n",
            "Top-1: 77.8\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_1x_sk0/saved_model/\n",
            "Top-1: 73.7\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_1x_sk1/saved_model/\n",
            "Top-1: 76.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_2x_sk0/saved_model/\n",
            "Top-1: 77.0\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r101_2x_sk1/saved_model/\n",
            "Top-1: 79.1\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_1x_sk0/saved_model/\n",
            "Top-1: 74.6\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_1x_sk1/saved_model/\n",
            "Top-1: 77.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_2x_sk0/saved_model/\n",
            "Top-1: 77.4\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_2x_sk1/saved_model/\n",
            "Top-1: 79.3\n",
            "gs://simclr-checkpoints-tf2/simclrv2/pretrained/r152_3x_sk1/saved_model/\n",
            "Top-1: 79.9\n"
          ]
        }
      ],
      "source": [
        "path_pat = \"gs://simclr-checkpoints-tf2/simclrv2/pretrained/r{depth}_{width_multiplier}x_sk{sk}/saved_model/\"\n",
        "results = {}\n",
        "\n",
        "for resnet_depth in (50, 101, 152):\n",
        "  for width_multiplier in (1, 2):\n",
        "    for sk in (0, 1):\n",
        "      path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "      results[path] = eval(path)\n",
        "      print(path)\n",
        "      print(\"Top-1: %.1f\" % (results[path] * 100))\n",
        "\n",
        "resnet_depth = 152\n",
        "width_multiplier = 3\n",
        "sk = 1\n",
        "path = path_pat.format(depth=resnet_depth, width_multiplier=width_multiplier, sk=sk)\n",
        "results[path] = eval(path)\n",
        "print(path)\n",
        "print(\"Top-1: %.1f\" % (results[path] * 100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dky_MGU7SVks"
      },
      "source": [
        "# SimCLR v1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1sUyTclWSZad"
      },
      "source": [
        "Finetuned"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "executionInfo": {
          "elapsed": 4096533,
          "status": "ok",
          "timestamp": 1605058511622,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "oTJq19MeSZHz",
        "outputId": "d1814bbc-faa4-4b36-d5e3-d310596d2c46"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/1x/saved_model/\n",
            "Top-1: 65.8\n",
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/2x/saved_model/\n",
            "Top-1: 71.6\n",
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_10pct/4x/saved_model/\n",
            "Top-1: 74.5\n",
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/1x/saved_model/\n",
            "Top-1: 75.6\n",
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/2x/saved_model/\n",
            "Top-1: 79.2\n",
            "gs://simclr-checkpoints-tf2/simclrv1/finetune_100pct/4x/saved_model/\n",
            "Top-1: 80.8\n"
          ]
        }
      ],
      "source": [
        "path_pat = \"gs://simclr-checkpoints-tf2/simclrv1/finetune_{pct}pct/{width_multiplier}x/saved_model/\"\n",
        "results = {}\n",
        "\n",
        "resnet_depth = 50\n",
        "for pct in (10, 100):\n",
        "  for width_multiplier in (1, 2, 4):\n",
        "    path = path_pat.format(pct=pct, width_multiplier=width_multiplier)\n",
        "    results[path] = eval(path)\n",
        "    print(path)\n",
        "    print(\"Top-1: %.1f\" % (results[path] * 100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fcNVhsLfIfT1"
      },
      "source": [
        "Pretrained with linear eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "executionInfo": {
          "elapsed": 2123208,
          "status": "ok",
          "timestamp": 1605047219252,
          "user": {
            "displayName": "Saurabh Saxena",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgV8J9J5asAJLihDECu1FLcPeJvRZ5oxWSyN2BFEA=s64",
            "userId": "09677429467784970077"
          },
          "user_tz": 480
        },
        "id": "PQHzzefZS397",
        "outputId": "39c2fcd3-dc86-4a8a-d618-c4a9609d91ec"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "gs://simclr-checkpoints-tf2/simclrv1/pretrain/1x/saved_model/\n",
            "Top-1: 69.0\n",
            "gs://simclr-checkpoints-tf2/simclrv1/pretrain/2x/saved_model/\n",
            "Top-1: 74.2\n",
            "gs://simclr-checkpoints-tf2/simclrv1/pretrain/4x/saved_model/\n",
            "Top-1: 76.6\n"
          ]
        }
      ],
      "source": [
        "path_pat = \"gs://simclr-checkpoints-tf2/simclrv1/pretrain/{width_multiplier}x/saved_model/\"\n",
        "results = {}\n",
        "\n",
        "resnet_depth = 50\n",
        "for width_multiplier in (1, 2, 4):\n",
        "  path = path_pat.format(width_multiplier=width_multiplier)\n",
        "  results[path] = eval(path)\n",
        "  print(path)\n",
        "  print(\"Top-1: %.1f\" % (results[path] * 100))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "SimCLR results for imagenet",
      "provenance": [
        {
          "file_id": "1Pg2RBI-SiywuKo6RjQJiQuuhuyzZn_Fm",
          "timestamp": 1605062386386
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
